{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DeezyMatch, example_001"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a new model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from DeezyMatch import train as dm_train\n",
    "\n",
    "# train a new model\n",
    "dm_train(input_file_path=\"../inputs/input_dfm_notebook_001.yaml\", \n",
    "         dataset_path=\"../dataset/dataset-string-matching_train.txt\", \n",
    "         model_name=\"test001\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DeezyMatch import plot_log\n",
    "\n",
    "# plot log file\n",
    "plot_log(path2log=\"./models/test001/log.txt\", \n",
    "         output_name=\"t001\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tune a pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from DeezyMatch import finetune as dm_finetune\n",
    "\n",
    "# fine-tune a pretrained model stored at pretrained_model_path and pretrained_vocab_path \n",
    "dm_finetune(input_file_path=\"../inputs/input_dfm_notebook_001.yaml\", \n",
    "            dataset_path=\"../dataset/dataset-string-matching_finetune.txt\", \n",
    "            model_name=\"finetuned_test001\",\n",
    "            pretrained_model_path=\"./models/test001/test001.model\", \n",
    "            pretrained_vocab_path=\"./models/test001/test001.vocab\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from DeezyMatch import inference as dm_inference\n",
    "\n",
    "# model inference using a model stored at pretrained_model_path and pretrained_vocab_path \n",
    "dm_inference(input_file_path=\"../inputs/input_dfm_notebook_001.yaml\",\n",
    "             dataset_path=\"../dataset/dataset-string-matching_test.txt\", \n",
    "             pretrained_model_path=\"./models/finetuned_test001/finetuned_test001.model\", \n",
    "             pretrained_vocab_path=\"./models/finetuned_test001/finetuned_test001.vocab\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate query vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DeezyMatch import inference as dm_inference\n",
    "\n",
    "# generate vectors for queries (specified in dataset_path) \n",
    "# using a model stored at pretrained_model_path and pretrained_vocab_path \n",
    "# the scenario name (queries/test) is reused as input_scenario when combining vectors\n",
    "dm_inference(input_file_path=\"../inputs/input_dfm_notebook_001.yaml\",\n",
    "             dataset_path=\"../dataset/dataset-queries.txt\", \n",
    "             pretrained_model_path=\"./models/finetuned_test001/finetuned_test001.model\", \n",
    "             pretrained_vocab_path=\"./models/finetuned_test001/finetuned_test001.vocab\",\n",
    "             inference_mode=\"vect\",\n",
    "             scenario=\"queries/test\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate candidate vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DeezyMatch import inference as dm_inference\n",
    "\n",
    "# generate vectors for candidates (specified in dataset_path) \n",
    "# using a model stored at pretrained_model_path and pretrained_vocab_path \n",
    "dm_inference(input_file_path=\"../inputs/input_dfm_notebook_001.yaml\",\n",
    "             dataset_path=\"../dataset/dataset-candidates.txt\", \n",
    "             pretrained_model_path=\"./models/finetuned_test001/finetuned_test001.model\", \n",
    "             pretrained_vocab_path=\"./models/finetuned_test001/finetuned_test001.vocab\",\n",
    "             inference_mode=\"vect\",\n",
    "             scenario=\"candidates/test\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assemble query vector representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from DeezyMatch import combine_vecs\n",
    "\n",
    "# combine vectors stored in queries/test and save them in combined/queries_test\n",
    "combine_vecs(rnn_passes=['fwd', 'bwd'], \n",
    "             input_scenario='queries/test', \n",
    "             output_scenario='combined/queries_test', \n",
    "             print_every=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assemble candidate vector representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from DeezyMatch import combine_vecs\n",
    "\n",
    "# combine vectors stored in candidates/test and save them in combined/candidates_test\n",
    "combine_vecs(rnn_passes=['fwd', 'bwd'], \n",
    "             input_scenario='candidates/test', \n",
    "             output_scenario='combined/candidates_test', \n",
    "             print_every=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Candidate Ranker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from DeezyMatch import candidate_ranker\n",
    "\n",
    "# Select candidates based on L2-norm distance (aka faiss distance):\n",
    "# find candidates from candidate_scenario \n",
    "# for queries specified in query_scenario\n",
    "candidates_pd = \\\n",
    "    candidate_ranker(query_scenario=\"./combined/queries_test\",\n",
    "                     candidate_scenario=\"./combined/candidates_test\", \n",
    "                     ranking_metric=\"faiss\", \n",
    "                     selection_threshold=5., \n",
    "                     num_candidates=2, \n",
    "                     search_size=2, \n",
    "                     output_path=\"ranker_results/test_candidates_deezymatch\", \n",
    "                     pretrained_model_path=\"./models/finetuned_test001/finetuned_test001.model\", \n",
    "                     pretrained_vocab_path=\"./models/finetuned_test001/finetuned_test001.vocab\", \n",
    "                     number_test_rows=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "candidates_pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DeezyMatch import candidate_ranker\n",
    "\n",
    "# Select candidates based on L2-norm distance (aka faiss distance):\n",
    "# find candidates from candidate_scenario \n",
    "# for queries specified in query_scenario\n",
    "candidates_pd = \\\n",
    "    candidate_ranker(query_scenario=\"./combined/queries_test\",\n",
    "                     candidate_scenario=\"./combined/candidates_test\", \n",
    "                     ranking_metric=\"cosine\", \n",
    "                     selection_threshold=0.9, \n",
    "                     num_candidates=2, \n",
    "                     search_size=2, \n",
    "                     output_path=\"ranker_results/test_candidates_deezymatch_cosine\", \n",
    "                     pretrained_model_path=\"./models/finetuned_test001/finetuned_test001.model\", \n",
    "                     pretrained_vocab_path=\"./models/finetuned_test001/finetuned_test001.vocab\", \n",
    "                     number_test_rows=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "candidates_pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Candidate ranking on-the-fly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from DeezyMatch import candidate_ranker\n",
    "\n",
    "# Ranking on-the-fly\n",
    "# find candidates from candidate_scenario \n",
    "# for queries specified by the `query` argument\n",
    "candidates_pd = \\\n",
    "    candidate_ranker(candidate_scenario=\"./combined/candidates_test\",\n",
    "                     query=[\"DeezyMatch\", \"kasra\", \"fede\", \"mariona\"],\n",
    "                     ranking_metric=\"faiss\", \n",
    "                     selection_threshold=5., \n",
    "                     num_candidates=1, \n",
    "                     search_size=100, \n",
    "                     output_path=\"ranker_results/test_candidates_deezymatch_on_the_fly\", \n",
    "                     pretrained_model_path=\"./models/finetuned_test001/finetuned_test001.model\", \n",
    "                     pretrained_vocab_path=\"./models/finetuned_test001/finetuned_test001.vocab\", \n",
    "                     number_test_rows=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "candidates_pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialise the candidate ranker once, so it can be reused with different queries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DeezyMatch import candidate_ranker_init\n",
    "\n",
    "# initializing candidate_ranker via candidate_ranker_init\n",
    "myranker = candidate_ranker_init(candidate_scenario=\"./combined/candidates_test\",\n",
    "                                 query=[\"DeezyMatch\", \"kasra\", \"fede\", \"mariona\"],\n",
    "                                 ranking_metric=\"faiss\", \n",
    "                                 selection_threshold=5., \n",
    "                                 num_candidates=1, \n",
    "                                 search_size=100, \n",
    "                                 output_path=\"ranker_results/test_candidates_deezymatch_on_the_fly\", \n",
    "                                 pretrained_model_path=\"./models/finetuned_test001/finetuned_test001.model\", \n",
    "                                 pretrained_vocab_path=\"./models/finetuned_test001/finetuned_test001.vocab\", \n",
    "                                 number_test_rows=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print the content of myranker by:\n",
    "print(myranker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# To rank the queries:\n",
    "myranker.rank()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The results are stored in:\n",
    "myranker.output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Change the queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myranker.set_query(query=[\"khan\", \"feng\", \"cheng\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# To rank the queries:\n",
    "myranker.rank()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The results are stored in:\n",
    "myranker.output"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
